import torch
from torch.optim import Optimizer
from typing import Any, Dict


class UGM(Optimizer):
    """Universal Gradient Method (UGM) with a backtracking line search.

    Every parameter tensor keeps its own curvature estimate ``rho`` (the
    inverse step size) in ``self.state[param]['rho']``. During ``step`` the
    estimate is doubled until the trial point ``x - grad / rho`` satisfies

        f(x_new) <= f(x) + <grad, x_new - x> + rho / 2 * ||x_new - x||^2 + eps / 2

    and the accepted estimate is halved before being stored, so the next
    step starts from a more optimistic step size.
    """

    __version__ = '1.0.0'

    def __init__(self, params, l0=1, eps=1e-6, max_iters=100):
        """
        Implements Universal Gradient Method (UGM) as a PyTorch Optimizer.

        Args:
            params: Iterable of parameters to optimize or dicts defining parameter groups.
            l0: Initial curvature estimate ``rho``; the first trial step size is ``1 / l0``.
            eps: Epsilon value used in the UGM condition.
            max_iters: Maximum number of iterations for the line search.

        Raises:
            ValueError: If ``l0`` or ``max_iters`` is not positive.
        """
        if l0 <= 0:
            raise ValueError(f"Invalid value for l0: {l0}. Must be positive.")
        if max_iters <= 0:
            raise ValueError(f"max_iters must be positive. Got: {max_iters}.")

        # l0 lives in the defaults so that parameters registered later
        # (add_param_group) can still be given an initial rho inside step().
        defaults = dict(l0=l0, eps=eps, max_iters=max_iters)
        super(UGM, self).__init__(params, defaults)

        # Seed the curvature estimate of every trainable parameter up front so
        # it is visible in the optimizer state before the first step.
        for group in self.param_groups:
            for param in group['params']:
                if param.requires_grad:
                    self.state[param] = {
                        'rho': group['l0'],
                    }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore optimizer state, backfilling ``l0`` for older checkpoints."""
        super(UGM, self).__setstate__(state)
        # Param groups saved before 'l0' was part of the defaults lack the key.
        fallback_l0 = self.defaults.get('l0', 1)
        for group in self.param_groups:
            group.setdefault('l0', fallback_l0)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Gradients must already be stored in ``param.grad`` when this method is
        called. The closure is invoked with gradient tracking disabled, once
        up front and once per line-search trial, so it should only evaluate
        the loss and must not modify the ``.grad`` tensors.

        Args:
            closure (callable): A closure that reevaluates the model and returns the loss.

        Returns:
            loss: The loss returned by the first closure call, i.e. evaluated
            before any parameter is updated.

        Raises:
            ValueError: If no closure is provided.
        """
        if closure is None:
            raise ValueError("A closure must be provided to reevaluate the model.")

        loss = closure()
        # Loss at the current iterate. Refreshed after each parameter's line
        # search so the next search measures its own decrease rather than
        # inheriting the decrease achieved by earlier parameters.
        reference_loss = loss

        for group in self.param_groups:
            eps = group['eps']
            max_iters = group['max_iters']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    # Frozen or unused parameter: nothing to update.
                    continue

                state = self.state[param]
                # Parameters without a seeded estimate (added after
                # construction, or not trainable at construction) start at l0.
                rho_star = state.get('rho', group['l0'])

                original_param = param.clone()
                grad_sq_norm = torch.sum(grad * grad)
                new_loss = reference_loss

                # Perform line search for step size
                for _ in range(max_iters):
                    param.copy_(original_param).add_(grad, alpha=-1.0 / rho_star)
                    new_loss = closure()

                    # Check UGM condition. With x_new - x = -grad / rho the
                    # model terms <grad, x_new - x> + rho / 2 * ||x_new - x||^2
                    # reduce to -||grad||^2 / (2 * rho).
                    upper_bound = reference_loss - grad_sq_norm / (2 * rho_star) + (eps / 2)
                    if new_loss <= upper_bound:
                        break

                    rho_star *= 2

                state['rho'] = (rho_star / 2)
                reference_loss = new_loss

        return loss

    def has_d_estimator(self):
        """Return ``False``; this method always reports no D estimator."""
        return False

